#!/usr/bin/env python3

import csv
import os
import socket
import subprocess
import sys
import tempfile
import threading
import traceback
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import StringIO
from socketserver import ThreadingMixIn
from urllib.parse import urljoin


def is_ipv6(host):
    """Return True unless *host* parses as an IPv4 address.

    The check is purely negative: any string that ``socket.inet_aton`` rejects
    is reported as IPv6, so this is only meaningful for inputs that are known
    to be IP address literals.

    Args:
        host: Address string to classify.

    Returns:
        False if ``socket.inet_aton`` accepts *host*, True otherwise.

    Raises:
        TypeError: If *host* is not a string.
    """
    try:
        socket.inet_aton(host)
    except (OSError, ValueError):
        # OSError: not a valid IPv4 literal. ValueError: the string cannot be
        # handed to the C API at all (e.g. embedded NUL). Everything else,
        # including KeyboardInterrupt, is deliberately not swallowed.
        return True
    return False


def get_local_port(host, ipv6):
    """Ask the OS for a currently free port on *host* and return its number.

    A throwaway socket of the requested address family is bound to port 0 so
    that the operating system picks the port; the socket is closed again
    before returning.

    Args:
        host: Local address to bind to.
        ipv6: Truthy to use an AF_INET6 socket, falsy for AF_INET.

    Returns:
        The port number the temporary socket was bound to.
    """
    address_family = socket.AF_INET6 if ipv6 else socket.AF_INET
    with socket.socket(address_family) as probe:
        probe.bind((host, 0))
        bound_address = probe.getsockname()
        return bound_address[1]


# Host and HTTP port used to reach ClickHouse; both can be overridden through
# environment variables of the same name.
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "127.0.0.1")
CLICKHOUSE_PORT_HTTP = os.getenv("CLICKHOUSE_PORT_HTTP", "8123")

#####################################################################################
# This test starts an HTTP server and serves data to a ClickHouse table backed by
# the URL engine. For the test to work, the IP address and port of the HTTP server
# (defined below) must be reachable from the ClickHouse server.
#####################################################################################

# IP address of this host that is reachable from the outside world.
# `hostname -i` output is split on whitespace and the first entry is used.
HTTP_SERVER_HOST = (
    subprocess.check_output(["hostname", "-i"]).decode("utf-8").split()[0]
)
IS_IPV6 = is_ipv6(HTTP_SERVER_HOST)
HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6)

# IP address and port of the HTTP server started from this script.
HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT)

# Base URL of that server; an IPv6 address is wrapped in square brackets.
if IS_IPV6:
    HTTP_SERVER_URL_STR = f"http://[{HTTP_SERVER_HOST}]:{HTTP_SERVER_PORT}/"
else:
    HTTP_SERVER_URL_STR = f"http://{HTTP_SERVER_HOST}:{HTTP_SERVER_PORT}/"

# Path of the scratch CSV file shared by the HTTP handler and the tests.
# Only the name is generated here (via private tempfile helpers); the file
# itself is created later, when it is first opened for writing.
CSV_DATA = os.path.join(
    tempfile._get_default_tempdir(), next(tempfile._get_candidate_names())
)


def get_ch_answer(query):
    """Send *query* to ClickHouse over HTTP and return the response as text.

    The endpoint is the ``CLICKHOUSE_URL`` environment variable when it is
    set; otherwise it is built from CLICKHOUSE_HOST and CLICKHOUSE_PORT_HTTP.

    Args:
        query: SQL text, sent as the body of a POST request.

    Returns:
        The decoded response body.

    Raises:
        Whatever ``urllib.request.urlopen`` raises on connection or HTTP
        errors; nothing is caught here.
    """
    host = CLICKHOUSE_HOST
    # A host containing ":" is an IPv6 literal and must be bracketed inside a
    # URL. The decision is based on the ClickHouse host itself: IS_IPV6
    # describes the local HTTP server address, not the ClickHouse endpoint.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    url = os.environ.get("CLICKHOUSE_URL", f"http://{host}:{CLICKHOUSE_PORT_HTTP}")
    # Close the response deterministically instead of leaving it to the GC.
    with urllib.request.urlopen(url, data=query.encode()) as response:
        return response.read().decode()


def check_answers(query, answer):
    """Run *query* on ClickHouse and verify the result equals *answer*.

    Both the fetched and the expected text are compared after stripping
    leading and trailing whitespace. On mismatch the query, the expected
    answer and the fetched answer are printed to stderr.

    Raises:
        Exception: If the stripped answers differ.
    """
    ch_answer = get_ch_answer(query)
    if ch_answer.strip() == answer.strip():
        return

    report = (
        ("FAIL on query:", query),
        ("Expected answer:", answer),
        ("Fetched answer :", ch_answer),
    )
    for label, value in report:
        print(label, value, file=sys.stderr)
    raise Exception("Fail on query")


class CSVHTTPServer(BaseHTTPRequestHandler):
    """Request handler that exposes the CSV_DATA file over HTTP.

    GET returns the file contents as CSV, HEAD returns the headers only and
    POST appends the rows found in the request body to the file. The POST
    body is parsed as a sequence of chunks (hex size line, payload, CRLF);
    the request headers are not inspected to confirm that encoding.
    """

    def _set_headers(self):
        """Send a 200 status and the ``text/csv`` content type header."""
        self.send_response(200)
        self.send_header("Content-type", "text/csv")
        self.end_headers()

    def do_GET(self):
        """Write every row of CSV_DATA to the client, one row per line."""
        self._set_headers()
        with open(CSV_DATA, "r") as fl:
            reader = csv.reader(fl, delimiter=",")
            for row in reader:
                # Fields are joined with a comma followed by a space.
                self.wfile.write((", ".join(row) + "\n").encode())

    def do_HEAD(self):
        """Answer with the response headers and no body."""
        self._set_headers()

    def _read_chunk_bytes(self):
        """Read one chunk from the request body and return its raw payload.

        Returns ``b""`` for the terminating zero-length chunk, and also when
        the stream ends or yields a blank size line, so callers always see a
        clean end-of-body marker instead of hanging or crashing.

        Raises:
            ValueError: If the size line is not a hexadecimal number.
        """
        size_line = self.rfile.readline()
        # Drop optional chunk extensions (";name=value") and the line ending.
        size_field = size_line.split(b";", 1)[0].strip()
        if not size_field:
            return b""
        length = int(size_field, 16)
        if length == 0:
            return b""
        content = self.rfile.read(length)
        self.rfile.read(2)  # consume the CRLF that terminates the payload
        return content

    def read_chunk(self):
        """Read one chunk from the request body and return it decoded as UTF-8.

        Returns an empty string at the end of the body (see
        ``_read_chunk_bytes``).
        """
        return self._read_chunk_bytes().decode("utf-8")

    def do_POST(self):
        """Append the CSV rows of the chunked request body to CSV_DATA."""
        chunks = []
        while True:
            chunk = self._read_chunk_bytes()
            if not chunk:
                break
            chunks.append(chunk)
        # Decode once over the whole body: a multi-byte UTF-8 character may
        # be split across two chunks, which per-chunk decoding cannot handle.
        data = b"".join(chunks).decode("utf-8")
        with StringIO(data) as fl:
            reader = csv.reader(fl, delimiter=",")
            with open(CSV_DATA, "a") as d:
                d.writelines(",".join(row) + "\n" for row in reader)
        self._set_headers()
        self.wfile.write(b"ok")

    def log_message(self, format, *args):
        """Suppress the default per-request logging."""


class HTTPServerV6(HTTPServer):
    """HTTPServer whose socket is created with the IPv6 address family."""
    address_family = socket.AF_INET6


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer combined with ThreadingMixIn for threaded request handling."""


class ThreadedHTTPServerV6(ThreadingMixIn, HTTPServerV6):
    """HTTPServerV6 combined with ThreadingMixIn for threaded request handling."""


def start_server():
    """Create the CSV HTTP server and the thread that will run it.

    The server class is chosen by IS_IPV6 and bound to HTTP_SERVER_ADDRESS.
    The thread is only constructed here; the caller is expected to start it.

    Returns:
        A ``(thread, httpd)`` tuple where the thread's target is
        ``httpd.serve_forever``.
    """
    server_class = ThreadedHTTPServerV6 if IS_IPV6 else ThreadedHTTPServer
    httpd = server_class(HTTP_SERVER_ADDRESS, CSVHTTPServer)
    serving_thread = threading.Thread(target=httpd.serve_forever)
    return serving_thread, httpd


# test section


def test_select(
    table_name="",
    schema="str String,numuint UInt32,numint Int32,double Float64",
    requests=None,
    answers=None,
    test_data="",
    res_path="",
):
    """Check SELECT queries against data served by the local HTTP server.

    The CSV_DATA file is reset to *test_data* (or emptied), then every query
    in *requests* is run and compared with the entry of *answers* at the same
    position.

    Args:
        table_name: If non-empty, a URL-engine table of that name is created
            for the checks and dropped afterwards. If empty, the ``url()``
            table function is queried instead.
        schema: Column definitions for the table or table function.
        requests: Query templates containing a ``{tbl}`` placeholder.
        answers: Expected results, parallel to *requests*.
        test_data: CSV text to store in CSV_DATA; a newline is appended.
        res_path: Path joined onto the server URL for the ``url()`` table
            function. Not used when *table_name* is given.

    Raises:
        Exception: From ``check_answers`` when a result does not match.
        IndexError: If *answers* is shorter than *requests*.
    """
    # None instead of [] as default: a list default is shared between calls.
    requests = [] if requests is None else requests
    answers = [] if answers is None else answers

    # Opening with "w" truncates the file; seed it in the same pass.
    with open(CSV_DATA, "w") as f:
        if test_data:
            f.write(test_data + "\n")

    if table_name:
        get_ch_answer("drop table if exists {}".format(table_name))
        get_ch_answer(
            "create table {} ({}) engine=URL('{}', 'CSV')".format(
                table_name, schema, HTTP_SERVER_URL_STR
            )
        )
        tbl = table_name
    else:
        tbl = "url('{addr}', 'CSV', '{schema}')".format(
            addr=urljoin(HTTP_SERVER_URL_STR, res_path), schema=schema
        )

    for position, request in enumerate(requests):
        check_answers(request.format(tbl=tbl), answers[position])

    if table_name:
        get_ch_answer("drop table if exists {}".format(table_name))


def test_insert(
    table_name="",
    schema="str String,numuint UInt32,numint Int32,double Float64",
    requests_insert=None,
    requests_select=None,
    answers=None,
):
    """Check INSERT queries that write through the local HTTP server.

    The CSV_DATA file is emptied, every query in *requests_insert* is run,
    and then each query in *requests_select* is run and compared with the
    entry of *answers* at the same position.

    Args:
        table_name: If non-empty, a URL-engine table of that name is created
            for the checks and dropped afterwards. If empty, the ``url()``
            table function is used for both inserts and selects.
        schema: Column definitions for the table or table function.
        requests_insert: INSERT templates containing a ``{tbl}`` placeholder.
        requests_select: SELECT templates containing a ``{tbl}`` placeholder.
        answers: Expected results, parallel to *requests_select*.

    Raises:
        Exception: From ``check_answers`` when a result does not match.
        IndexError: If *answers* is shorter than *requests_select*.
    """
    # None instead of [] as default: a list default is shared between calls.
    requests_insert = [] if requests_insert is None else requests_insert
    requests_select = [] if requests_select is None else requests_select
    answers = [] if answers is None else answers

    # Opening with "w" is enough to empty the file.
    with open(CSV_DATA, "w"):
        pass

    if table_name:
        get_ch_answer("drop table if exists {}".format(table_name))
        get_ch_answer(
            "create table {} ({}) engine=URL('{}', 'CSV')".format(
                table_name, schema, HTTP_SERVER_URL_STR
            )
        )
        insert_target = table_name
        select_source = table_name
    else:
        select_source = "url('{addr}', 'CSV', '{schema}')".format(
            addr=HTTP_SERVER_URL_STR, schema=schema
        )
        # INSERT needs the "table function" keyword in front of url(...).
        insert_target = "table function " + select_source

    for request in requests_insert:
        get_ch_answer(request.format(tbl=insert_target))

    for position, request in enumerate(requests_select):
        check_answers(request.format(tbl=select_source), answers[position])

    if table_name:
        get_ch_answer("drop table if exists {}".format(table_name))


def main():
    """Run all URL engine and ``url()`` table function checks.

    Starts the local CSV HTTP server, then runs in order: selects through a
    URL-engine table, selects through the ``url()`` table function, selects
    of the ``_path`` / ``_file`` virtual columns, inserts through a
    URL-engine table and inserts through the ``url()`` table function.
    Prints ``PASSED`` when every check succeeds.

    The server is shut down and the scratch CSV file is removed whether or
    not the checks succeed.

    Raises:
        Exception: Propagated from the first failing check or query.
    """
    test_data = "Hello,2,-2,7.7\nWorld,2,-5,8.8"
    # Each dict maps a query template to its expected (tab-separated) result.
    select_only_requests = {
        "select str,numuint,numint,double from {tbl}": test_data.replace(",", "\t"),
        "select numuint, count(*) from {tbl} group by numuint": "2\t2",
        "select str,numuint,numint,double from {tbl} limit 1": test_data.split("\n")[
            0
        ].replace(",", "\t"),
    }

    insert_requests = [
        "insert into {tbl} values('Hello',10,-2,7.7)('World',10,-5,7.7)",
        "insert into {tbl} select 'Buy', number, 9-number, 9.9 from system.numbers limit 10",
    ]

    select_requests = {
        "select distinct numuint from {tbl} order by numuint": "\n".join(
            str(i) for i in range(11)
        ),
        "select count(*) from {tbl}": "12",
        "select double, count(*) from {tbl} group by double order by double": "7.7\t2\n9.9\t10",
    }

    # The virtual-column checks request the CSV file's own path from the
    # server, so _path / _file are expected to echo that path and its name
    # once per row of test_data (two rows).
    pathname = CSV_DATA
    filename = os.path.basename(CSV_DATA)
    select_virtual_requests = {
        "select _path from {tbl}": "\n".join([pathname] * 2),
        "select _file from {tbl}": "\n".join([filename] * 2),
        # NOTE(review): the trailing comma after _file is kept exactly as in
        # the original query; presumably intentional, confirm.
        "select _file, from {tbl} order by _path": "\n".join([filename] * 2),
        "select _path, _file from {tbl}": "\n".join(
            [f"{pathname}\t{filename}"] * 2
        ),
        "select _path, count(*) from {tbl} group by _path": f"{pathname}\t2",
    }

    t, httpd = start_server()
    t.start()
    try:
        # test table with url engine
        test_select(
            table_name="test_table_select",
            requests=list(select_only_requests.keys()),
            answers=list(select_only_requests.values()),
            test_data=test_data,
        )
        # test table function url
        test_select(
            requests=list(select_only_requests.keys()),
            answers=list(select_only_requests.values()),
            test_data=test_data,
        )
        # test table function url for virtual column
        test_select(
            requests=list(select_virtual_requests.keys()),
            answers=list(select_virtual_requests.values()),
            test_data=test_data,
            res_path=CSV_DATA,
        )

        # test insert into table with url engine
        test_insert(
            table_name="test_table_insert",
            requests_insert=insert_requests,
            requests_select=list(select_requests.keys()),
            answers=list(select_requests.values()),
        )
        # test insert into table function url
        test_insert(
            requests_insert=insert_requests,
            requests_select=list(select_requests.keys()),
            answers=list(select_requests.values()),
        )
    finally:
        # Stop the server on failure as well, so the serving thread does not
        # outlive the checks.
        httpd.shutdown()
        t.join()
        # The scratch file is only useful while the checks run.
        try:
            os.remove(CSV_DATA)
        except FileNotFoundError:
            pass  # never created, e.g. the very first check failed early

    print("PASSED")


if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        # Report where the failure happened, then the message itself.
        traceback.print_tb(ex.__traceback__, file=sys.stderr)
        print(ex, file=sys.stderr)
        sys.stderr.flush()

        # os._exit ends the process immediately with status 1, without
        # waiting for any threads that may still be running.
        os._exit(1)
